﻿// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using Microsoft.Spark.CSharp.Core;
using Microsoft.Spark.CSharp.Interop;
using Microsoft.Spark.CSharp.Interop.Ipc;
using Microsoft.Spark.CSharp.Proxy;
using Microsoft.Spark.CSharp.Proxy.Ipc;
using Microsoft.Spark.CSharp.Services;

namespace Microsoft.Spark.CSharp.Streaming
{
    /// <summary>
    /// The DStream of mapped records produced by a `mapWithState` operation on a pair DStream.
    /// It also exposes the companion stream of state snapshots, i.e. the state data of every key
    /// after a batch has been applied to it.
    /// </summary>
    /// <typeparam name="K">Type of the key</typeparam>
    /// <typeparam name="V">Type of the value</typeparam>
    /// <typeparam name="S">Type of the state data</typeparam>
    /// <typeparam name="M">Type of the mapped data</typeparam>
    [Serializable]
    public class MapWithStateDStream<K, V, S, M> : DStream<M>
    {
        // Stream of (key, state) pairs handed back unchanged by StateSnapshots().
        internal DStream<Tuple<K, S>> snapshotsDStream;

        /// <summary>
        /// Wraps the mapped-data stream (reusing its proxy and streaming context) together with the snapshot stream.
        /// </summary>
        internal MapWithStateDStream(DStream<M> mappedDataDStream, DStream<Tuple<K, S>> snapshotsDStream)
            : base(mappedDataDStream.DStreamProxy, mappedDataDStream.streamingContext)
        {
            this.snapshotsDStream = snapshotsDStream;
        }

        /// <summary>
        /// Gets the pair DStream whose RDDs are snapshots of the state of all keys.
        /// </summary>
        /// <returns>The snapshot stream supplied at construction.</returns>
        public DStream<Tuple<K, S>> StateSnapshots()
        {
            DStream<Tuple<K, S>> snapshots = snapshotsDStream;
            return snapshots;
        }
    }

    /// <summary>
    /// Pairs a state value with the timestamp (in ticks) at which it was created or last updated.
    /// The class does not need to be explicitly cloneable: serializing and deserializing it in the
    /// Worker already produces an independent copy.
    /// </summary>
    /// <typeparam name="S">Type of the state data</typeparam>
    [Serializable]
    internal class KeyedState<S>
    {
        // The wrapped state value.
        internal S state;
        // Timestamp, in ticks, of the last creation/update of the state.
        internal long ticks;

        /// <summary>
        /// Creates an instance whose fields keep their default values.
        /// </summary>
        internal KeyedState()
        {
        }

        /// <summary>
        /// Creates an instance holding <paramref name="state"/> stamped with <paramref name="ticks"/>.
        /// </summary>
        internal KeyedState(S state, long ticks)
            : this()
        {
            this.ticks = ticks;
            this.state = state;
        }
    }

    /// <summary>
    /// The single record that carries the keyed state of one MapWithStateRDD partition.
    /// It holds the per-key state map plus the values returned by the MapWithState mapping function.
    /// Note: explicit cloning is unnecessary, because serialization and deserialization in the Worker already copy it.
    /// </summary>
    /// <typeparam name="K">Type of the key</typeparam>
    /// <typeparam name="S">Type of the state data</typeparam>
    /// <typeparam name="M">Type of the mapped data</typeparam>
    [Serializable]
    internal class MapWithStateRDDRecord<K, S, M>
    {
        // State per key, each stamped with the ticks of its last update.
        internal Dictionary<K, KeyedState<S>> stateMap;
        // Values produced by the mapping function.
        internal List<M> mappedData;

        /// <summary>
        /// Creates a record with an empty state map and no mapped data.
        /// </summary>
        public MapWithStateRDDRecord()
        {
            stateMap = new Dictionary<K, KeyedState<S>>();
            mappedData = new List<M>();
        }

        /// <summary>
        /// Creates a record whose state map is seeded from (key, state) pairs, all stamped with <paramref name="t"/>.
        /// When a key appears more than once, the last pair wins.
        /// </summary>
        /// <param name="t">Timestamp (ticks) assigned to every seeded state.</param>
        /// <param name="iter">Initial (key, state) pairs.</param>
        public MapWithStateRDDRecord(long t, IEnumerable<Tuple<K, S>> iter)
            : this()
        {
            foreach (Tuple<K, S> entry in iter)
            {
                K key = entry.Item1;
                stateMap[key] = new KeyedState<S>(entry.Item2, t);
            }
        }
    }

    /// <summary>
    /// Helper class to update states for a RDD partition.
    /// The partition is expected to start with the MapWithStateRDDRecord of the previous batch, followed by the
    /// (key, value) pairs of the current batch. The mapping function is applied to each pair, the state map is
    /// updated accordingly, and idle states are optionally timed out.
    /// Reference: https://github.com/apache/spark/blob/master/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
    /// </summary>
    /// <typeparam name="K">Type of the key</typeparam>
    /// <typeparam name="V">Type of the value</typeparam>
    /// <typeparam name="S">Type of the state data</typeparam>
    /// <typeparam name="M">Type of the mapped data</typeparam>
    [Serializable]
    internal class UpdateStateHelper<K, V, S, M>
    {
        // Not serialized. Field initializers do not run when a [Serializable] object is deserialized, so after
        // this helper is shipped to a worker the field starts out null. Always access it through Logger.
        [NonSerialized]
        private ILoggerService logger;

        private readonly Func<K, V, State<S>, M> f;
        private readonly long ticks;
        private readonly bool removeTimedoutData;
        private readonly TimeSpan idleDuration;

        /// <summary>
        /// Creates the helper.
        /// </summary>
        /// <param name="f">User mapping function, applied to every (key, value, state) triple.</param>
        /// <param name="ticks">Timestamp (ticks) of the current batch; stamped on every updated state.</param>
        /// <param name="removeTimedoutData">Whether idle states are timed out and removed in this pass.</param>
        /// <param name="idleDuration">States not updated for at least this long are considered idle.</param>
        internal UpdateStateHelper(Func<K, V, State<S>, M> f, long ticks, bool removeTimedoutData, TimeSpan idleDuration)
        {
            this.f = f;
            this.ticks = ticks;
            this.removeTimedoutData = removeTimedoutData;
            this.idleDuration = idleDuration;
        }

        // Lazily (re)creates the logger, so it also works after deserialization on a worker.
        private ILoggerService Logger
        {
            get
            {
                if (logger == null)
                {
                    logger = LoggerServiceFactory.GetLogger(typeof(UpdateStateHelper<K, V, S, M>));
                }
                return logger;
            }
        }

        /// <summary>
        /// Applies the mapping function to the current batch of one partition and returns the updated state record.
        /// </summary>
        /// <param name="pid">Partition index (unused).</param>
        /// <param name="iter">The previous state record followed by the new Tuple&lt;K, V&gt; pairs.</param>
        /// <returns>A single-element sequence containing the updated MapWithStateRDDRecord.</returns>
        /// <exception cref="InvalidOperationException">The partition does not start with a state record.</exception>
        internal IEnumerable<dynamic> Execute(int pid, IEnumerable<dynamic> iter)
        {
            using (var enumerator = iter.GetEnumerator())
            {
                // The first element is the state record; it is updated in place.
                var stateRddRecord = GetStateRecord(enumerator);

                while (enumerator.MoveNext())
                {
                    Tuple<K, V> kv = enumerator.Current;
                    KeyedState<S> keyedState;
                    // NOTE(review): for a value-type S, State treats default(S) as an existing state (it boxes to
                    // non-null), so Exists() is true even for a key seen for the first time.
                    State<S> wrappedState = stateRddRecord.stateMap.TryGetValue(kv.Item1, out keyedState)
                        ? new State<S>(keyedState.state)
                        : new State<S>(default(S));

                    stateRddRecord.mappedData.Add(InvokeMappingFunction(kv.Item1, kv.Item2, wrappedState));

                    if (wrappedState.removed)
                    {
                        stateRddRecord.stateMap.Remove(kv.Item1);
                    }
                    else if (wrappedState.updated || wrappedState.defined)
                    {
                        // Re-stamping with the current batch time also resets the key's idle timeout.
                        stateRddRecord.stateMap[kv.Item1] = new KeyedState<S>(wrappedState.state, ticks);
                    }
                }

                if (removeTimedoutData)
                {
                    RemoveTimedOutStates(stateRddRecord);
                }

                return new[] { stateRddRecord };
            }
        }

        /// <summary>
        /// Reads the first element of the partition, which must be the state record.
        /// </summary>
        /// <exception cref="InvalidOperationException">The enumerator has no element.</exception>
        internal MapWithStateRDDRecord<K, S, M> GetStateRecord(IEnumerator<dynamic> enumerator)
        {
            if (enumerator.MoveNext())
            {
                return enumerator.Current;
            }

            throw new InvalidOperationException("MapWithStateRDDRecord is missing.");
        }

        // Calls the mapping function one final time, with a timing-out state, for each state older than
        // idleDuration; collects the results and then drops those states.
        private void RemoveTimedOutStates(MapWithStateRDDRecord<K, S, M> stateRddRecord)
        {
            long timeoutThresholdInTicks = ticks - idleDuration.Ticks;
            var toBeRemovedKeys = new List<K>();

            foreach (KeyValuePair<K, KeyedState<S>> entry in stateRddRecord.stateMap)
            {
                if (entry.Value.ticks >= timeoutThresholdInTicks)
                {
                    continue;
                }

                var timingOutState = new State<S>(entry.Value.state, true);
                stateRddRecord.mappedData.Add(InvokeMappingFunction(entry.Key, default(V), timingOutState));
                toBeRemovedKeys.Add(entry.Key);
            }

            // Removal is deferred because a Dictionary must not be modified while it is being enumerated.
            foreach (var key in toBeRemovedKeys)
            {
                stateRddRecord.stateMap.Remove(key);
            }
        }

        // Runs the user mapping function. An exception is logged and default(M) is recorded instead, so a single
        // failing record does not fail the partition (deliberate best-effort behavior).
        private M InvokeMappingFunction(K key, V value, State<S> state)
        {
            try
            {
                return f(key, value, state);
            }
            catch (Exception e)
            {
                Logger.LogException(e);
                return default(M);
            }
        }
    }

    /// <summary>
    /// Builds, for one batch, the state RDD of a `mapWithState` DStream.
    /// It applies any preceding transformation to the new values and partitions them. If there is no state yet,
    /// it seeds one (from the initial state when provided). Finally it runs UpdateStateHelper over the union of
    /// the state RDD and the new values.
    /// </summary>
    /// <typeparam name="K">Type of the key</typeparam>
    /// <typeparam name="V">Type of the value</typeparam>
    /// <typeparam name="S">Type of the state data</typeparam>
    /// <typeparam name="M">Type of the mapped data</typeparam>
    [Serializable]
    internal class MapWithStateHelper<K, V, S, M>
    {
        private static readonly DateTime UnixTimeEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
        private readonly Func<double, RDD<dynamic>, RDD<dynamic>> prevFunc;
        private readonly StateSpec<K, V, S, M> stateSpec;

        /// <summary>
        /// Creates the helper.
        /// </summary>
        /// <param name="prevF">Optional transformation applied to the values RDD first; may be null.</param>
        /// <param name="stateSpec">Specification of the `mapWithState` operation.</param>
        internal MapWithStateHelper(Func<double, RDD<dynamic>, RDD<dynamic>> prevF, StateSpec<K, V, S, M> stateSpec)
        {
            prevFunc = prevF;
            this.stateSpec = stateSpec;
        }

        /// <summary>
        /// Computes the new state RDD for the batch at time <paramref name="t"/>.
        /// </summary>
        /// <param name="t">Batch time, interpreted as milliseconds since the Unix epoch (UTC).</param>
        /// <param name="stateRDD">State RDD of the previous batch, or null when no state exists yet.</param>
        /// <param name="valuesRDD">New values of the batch; converted to Tuple&lt;K, V&gt; elements.</param>
        /// <returns>An RDD with one updated MapWithStateRDDRecord per partition.</returns>
        internal RDD<dynamic> Execute(double t, RDD<dynamic> stateRDD, RDD<dynamic> valuesRDD)
        {
            long ticks = UnixTimeEpoch.AddMilliseconds(t).Ticks;

            if (prevFunc != null)
            {
                valuesRDD = prevFunc(t, valuesRDD);
            }

            var values = valuesRDD.ConvertTo<Tuple<K, V>>().PartitionBy(stateSpec.numPartitions);

            if (stateRDD == null)
            {
                if (stateSpec.initialState != null)
                {
                    // The initial-state RDD may not be bound to a SparkContext yet; borrow the one of this batch.
                    if (stateSpec.initialState.sparkContext == null)
                    {
                        stateSpec.initialState.sparkContext = valuesRDD.sparkContext;
                    }
                    var partitionedInitialState = stateSpec.initialState.PartitionBy(stateSpec.numPartitions);
                    stateRDD = partitionedInitialState.MapPartitions(new MapWithStateMapPartitionHelper<K, V, S, M>(ticks).Execute, true).ConvertTo<dynamic>();
                }
                else
                {
                    // `values` is already partitioned with stateSpec.numPartitions, so no further PartitionBy is
                    // needed to get one empty state record per partition.
                    stateRDD = values.MapPartitions(new MapWithStateMapPartitionHelper<K, V, S, M>(ticks).ExecuteWithoutInitialState, true).ConvertTo<dynamic>();
                }
            }

            // Idle states are timed out only when a timeout is set and the state RDD is checkpointed.
            bool removeTimedoutData = stateSpec.idleDuration.Ticks != 0 && stateRDD.IsCheckpointed;
            // NOTE(review): presumably lets Union keep partitions aligned. That way each partition begins with its
            // state record followed by its values, which is the layout UpdateStateHelper.Execute reads. TODO confirm against RDD.Union.
            stateRDD.partitioner = values.partitioner;
            RDD<dynamic> union = stateRDD.Union(values.ConvertTo<dynamic>());

            return union.MapPartitionsWithIndex(new UpdateStateHelper<K, V, S, M>(stateSpec.mappingFunction, ticks, removeTimedoutData, stateSpec.idleDuration).Execute, true);
        }
    }

    /// <summary>
    /// Produces the initial MapWithStateRDDRecord of a partition, either seeded from initial (key, state)
    /// pairs or empty.
    /// </summary>
    /// <typeparam name="K">Type of the key</typeparam>
    /// <typeparam name="V">Type of the value</typeparam>
    /// <typeparam name="S">Type of the state data</typeparam>
    /// <typeparam name="M">Type of the mapped data</typeparam>
    [Serializable]
    internal class MapWithStateMapPartitionHelper<K, V, S, M>
    {
        // Timestamp (ticks) stamped on every state seeded by Execute.
        internal long ticks;

        internal MapWithStateMapPartitionHelper(long ticks)
        {
            this.ticks = ticks;
        }

        /// <summary>
        /// Wraps all (key, state) pairs of the partition into a single state record.
        /// </summary>
        internal IEnumerable<MapWithStateRDDRecord<K, S, M>> Execute(IEnumerable<Tuple<K, S>> iter)
        {
            var seeded = new MapWithStateRDDRecord<K, S, M>(ticks, iter);
            return new[] { seeded };
        }

        /// <summary>
        /// Returns one empty state record for the partition; the incoming values are not read.
        /// </summary>
        internal IEnumerable<MapWithStateRDDRecord<K, S, M>> ExecuteWithoutInitialState(IEnumerable<Tuple<K, V>> iter)
        {
            var empty = new MapWithStateRDDRecord<K, S, M>();
            return new[] { empty };
        }
    }

    /// <summary>
    /// Representing all the specifications of the DStream transformation `mapWithState` operation.
    /// The setter methods modify this instance and return it, so calls can be chained.
    /// </summary>
    /// <typeparam name="K">Type of the key</typeparam>
    /// <typeparam name="V">Type of the value</typeparam>
    /// <typeparam name="S">Type of the state data</typeparam>
    /// <typeparam name="M">Type of the mapped data</typeparam>
    [Serializable]
    public class StateSpec<K, V, S, M>
    {
        internal Func<K, V, State<S>, M> mappingFunction;
        // 0 means "not set".
        internal int numPartitions;
        // TimeSpan.Zero means no timeout.
        internal TimeSpan idleDuration = TimeSpan.FromTicks(0);
        internal RDD<Tuple<K, S>> initialState = null;

        /// <summary>
        /// Create a StateSpec for setting all the specifications of the `mapWithState` operation on a pair DStream.
        /// </summary>
        /// <param name="mappingFunction">The function applied on every data item to manage the associated state and generate the mapped data</param>
        /// <exception cref="ArgumentNullException"><paramref name="mappingFunction"/> is null.</exception>
        public StateSpec(Func<K, V, State<S>, M> mappingFunction)
        {
            // Fail fast here instead of with a NullReferenceException later inside the mapping step.
            if (mappingFunction == null)
            {
                throw new ArgumentNullException("mappingFunction");
            }
            this.mappingFunction = mappingFunction;
        }

        /// <summary>
        /// Set the number of partitions by which the state RDDs generated by `mapWithState` will be partitioned.
        /// Hash partitioning will be used.
        /// </summary>
        /// <param name="numPartitions">The number of partitions</param>
        /// <returns>This StateSpec instance, for call chaining</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="numPartitions"/> is negative.</exception>
        public StateSpec<K, V, S, M> NumPartitions(int numPartitions)
        {
            if (numPartitions < 0)
            {
                throw new ArgumentOutOfRangeException("numPartitions", numPartitions, "The number of partitions must not be negative.");
            }
            this.numPartitions = numPartitions;
            return this;
        }

        /// <summary>
        /// Set the duration after which the state of an idle key will be removed. A key and its state are
        /// considered idle if the key has not received any data for at least the given duration. The
        /// mapping function is called one final time on the idle states that are going to be removed;
        /// <see cref="State{S}.IsTimingOut"/> returns true in that call.
        /// </summary>
        /// <param name="idleDuration">The idle time of duration; TimeSpan.Zero disables the timeout</param>
        /// <returns>This StateSpec instance, for call chaining</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="idleDuration"/> is negative.</exception>
        public StateSpec<K, V, S, M> Timeout(TimeSpan idleDuration)
        {
            // A negative duration would put the timeout threshold in the future and expire every key.
            if (idleDuration < TimeSpan.Zero)
            {
                throw new ArgumentOutOfRangeException("idleDuration", idleDuration, "The idle duration must not be negative.");
            }
            this.idleDuration = idleDuration;
            return this;
        }

        /// <summary>
        /// Set the RDD containing the initial states that will be used by mapWithState
        /// </summary>
        /// <param name="initialState">The given initial state</param>
        /// <returns>This StateSpec instance, for call chaining</returns>
        public StateSpec<K, V, S, M> InitialState(RDD<Tuple<K, S>> initialState)
        {
            this.initialState = initialState;
            return this;
        }
    }

    /// <summary>
    /// Class for getting and updating the state inside the mapping function of the `mapWithState` operation.
    /// </summary>
    /// <typeparam name="S">Type of the state</typeparam>
    [Serializable]
    public class State<S>
    {
        internal S state = default(S);

        // Status flags describing what happened to the state; they are not serialized.
        [NonSerialized]
        internal bool defined = false;
        [NonSerialized]
        internal bool timingOut = false;
        [NonSerialized]
        internal bool updated = false;
        [NonSerialized]
        internal bool removed = false;

        /// <summary>
        /// Wraps a state value for one call of the mapping function.
        /// </summary>
        /// <param name="state">The current state value, or default(S) when the key has none.</param>
        /// <param name="timingOut">
        /// true when the key is being removed because it was idle. The state then counts as existing and
        /// can no longer be updated or removed.
        /// </param>
        internal State(S state, bool timingOut = false)
        {
            this.state = state;
            this.timingOut = timingOut;
            removed = false;
            updated = false;

            if (!timingOut)
            {
                // Existence is inferred from a non-null value.
                // NOTE(review): for a value-type S, default(S) (e.g. 0) boxes to a non-null object, so
                // Exists() reports true even for a key that has no state yet.
                defined = !ReferenceEquals(null, state);
            }
            else
            {
                defined = true;
            }
        }

        /// <summary>
        /// Returns whether the state already exists
        /// </summary>
        /// <returns>true, if the state already exists; otherwise, false.</returns>
        public bool Exists()
        {
            return defined;
        }

        /// <summary>
        /// Gets the state if it exists, otherwise it will throw ArgumentException.
        /// </summary>
        /// <returns>The state</returns>
        /// <exception cref="ArgumentException">ArgumentException if it does not exist.</exception>
        public S Get()
        {
            if (defined)
            {
                return state;
            }
            throw new ArgumentException("State is not set");
        }

        /// <summary>
        /// Updates the state with a new value.
        /// </summary>
        /// <param name="newState">The new state</param>
        /// <exception cref="ArgumentException">ArgumentException if the state has already been removed or is timing out</exception>
        public void Update(S newState)
        {
            if (removed || timingOut)
            {
                throw new ArgumentException("Cannot update the state that is timing out or has been removed.");
            }
            state = newState;
            defined = true;
            updated = true;
        }

        /// <summary>
        /// Removes the state if it exists.
        /// </summary>
        /// <exception cref="ArgumentException">ArgumentException if the state has already been removed or is timing out</exception>
        public void Remove()
        {
            if (removed || timingOut)
            {
                throw new ArgumentException("Cannot remove the state that is timing out or has already been removed.");
            }
            defined = false;
            updated = false;
            removed = true;
        }

        /// <summary>
        /// Returns whether the state is timing out and going to be removed by the system after the current batch.
        /// </summary>
        /// <returns>true, if it is timing out; otherwise, false.</returns>
        public bool IsTimingOut()
        {
            return timingOut;
        }
    }
}
